c90681
@@ -1,11 +1,11 @@
 /*
- * Copyright 2005-2010 the original author or authors.
+ * Copyright 2005-2011 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
- *      http://www.apache.org/licenses/LICENSE-2.0
+ *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -17,16 +17,15 @@
 package org.springframework.ws.transport.http;
 
 import java.util.Map;
-import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
-import org.springframework.beans.BeansException;
 import org.springframework.beans.factory.BeanFactoryUtils;
 import org.springframework.beans.factory.BeanInitializationException;
 import org.springframework.beans.factory.BeanNameAware;
 import org.springframework.beans.factory.InitializingBean;
 import org.springframework.beans.factory.NoSuchBeanDefinitionException;
+import org.springframework.context.ApplicationContext;
 import org.springframework.web.servlet.DispatcherServlet;
 import org.springframework.web.servlet.FrameworkServlet;
 import org.springframework.web.util.WebUtils;
@@ -222,14 +221,12 @@
public class MessageDispatcherServlet extends FrameworkServlet {
         messageReceiverHandlerAdapter.handle(httpServletRequest, httpServletResponse, messageReceiver);
     }
 
+    /**
+     * This implementation delegates to {@link #initStrategies}, passing along the given context.
+     */
     @Override
-    protected void initFrameworkServlet() throws ServletException, BeansException {
-        initMessageReceiverHandlerAdapter();
-        initWsdlDefinitionHandlerAdapter();
-        initXsdSchemaHandlerAdapter();
-        initMessageReceiver();
-        initWsdlDefinitions();
-        initXsdSchemas();
+    protected void onRefresh(ApplicationContext context) {
+        initStrategies(context);
     }
 
     @Override
@@ -264,7 +261,7 @@
public class MessageDispatcherServlet extends FrameworkServlet {
         if (HttpTransportConstants.METHOD_GET.equals(request.getMethod()) &&
                 request.getRequestURI().endsWith(WSDL_SUFFIX_NAME)) {
             String fileName = WebUtils.extractFilenameFromUrlPath(request.getRequestURI());
-            return (WsdlDefinition) wsdlDefinitions.get(fileName);
+            return wsdlDefinitions.get(fileName);
         }
         else {
             return null;
@@ -285,24 +282,37 @@
public class MessageDispatcherServlet extends FrameworkServlet {
         if (HttpTransportConstants.METHOD_GET.equals(request.getMethod()) &&
                 request.getRequestURI().endsWith(XSD_SUFFIX_NAME)) {
             String fileName = WebUtils.extractFilenameFromUrlPath(request.getRequestURI());
-            return (XsdSchema) xsdSchemas.get(fileName);
+            return xsdSchemas.get(fileName);
         }
         else {
             return null;
         }
     }
 
-    private void initMessageReceiverHandlerAdapter() {
+    /**
+     * Initialize the handler adapters, message receiver, WSDL definitions and XSD schemas using the given context.
+     * <p>Subclasses may override to initialize further strategies; invoke {@code super} to keep those set up here.
+     */
+    protected void initStrategies(ApplicationContext context) {
+        initMessageReceiverHandlerAdapter(context);
+        initWsdlDefinitionHandlerAdapter(context);
+        initXsdSchemaHandlerAdapter(context);
+        initMessageReceiver(context);
+        initWsdlDefinitions(context);
+        initXsdSchemas(context);
+    }
+
+
+    private void initMessageReceiverHandlerAdapter(ApplicationContext context) {
         try {
             try {
-                messageReceiverHandlerAdapter = (WebServiceMessageReceiverHandlerAdapter) getWebApplicationContext()
-                        .getBean(getMessageReceiverHandlerAdapterBeanName(),
-                                WebServiceMessageReceiverHandlerAdapter.class);
+                messageReceiverHandlerAdapter = context.getBean(getMessageReceiverHandlerAdapterBeanName(),
+                        WebServiceMessageReceiverHandlerAdapter.class);
             }
             catch (NoSuchBeanDefinitionException ignored) {
                 messageReceiverHandlerAdapter = new WebServiceMessageReceiverHandlerAdapter();
             }
-            initWebServiceMessageFactory();
+            initWebServiceMessageFactory(context);
             messageReceiverHandlerAdapter.afterPropertiesSet();
         }
         catch (Exception ex) {
@@ -310,15 +320,14 @@
public class MessageDispatcherServlet extends FrameworkServlet {
         }
     }
 
-    private void initWebServiceMessageFactory() {
+    private void initWebServiceMessageFactory(ApplicationContext context) {
         WebServiceMessageFactory messageFactory;
         try {
-            messageFactory = (WebServiceMessageFactory) getWebApplicationContext()
-                    .getBean(getMessageFactoryBeanName(), WebServiceMessageFactory.class);
+            messageFactory = context.getBean(getMessageFactoryBeanName(), WebServiceMessageFactory.class);
         }
         catch (NoSuchBeanDefinitionException ignored) {
-            messageFactory = (WebServiceMessageFactory) defaultStrategiesHelper
-                    .getDefaultStrategy(WebServiceMessageFactory.class, getWebApplicationContext());
+            messageFactory = defaultStrategiesHelper
+                    .getDefaultStrategy(WebServiceMessageFactory.class, context);
             if (logger.isDebugEnabled()) {
                 logger.debug("No WebServiceMessageFactory found in servlet '" + getServletName() + "': using default");
             }
@@ -326,11 +335,11 @@
public class MessageDispatcherServlet extends FrameworkServlet {
         messageReceiverHandlerAdapter.setMessageFactory(messageFactory);
     }
 
-    private void initWsdlDefinitionHandlerAdapter() {
+    private void initWsdlDefinitionHandlerAdapter(ApplicationContext context) {
         try {
             try {
-                wsdlDefinitionHandlerAdapter = (WsdlDefinitionHandlerAdapter) getWebApplicationContext()
-                        .getBean(getWsdlDefinitionHandlerAdapterBeanName(), WsdlDefinitionHandlerAdapter.class);
+                wsdlDefinitionHandlerAdapter =
+                        context.getBean(getWsdlDefinitionHandlerAdapterBeanName(), WsdlDefinitionHandlerAdapter.class);
 
             }
             catch (NoSuchBeanDefinitionException ignored) {
@@ -344,10 +353,10 @@
public class MessageDispatcherServlet extends FrameworkServlet {
         }
     }
 
-    private void initXsdSchemaHandlerAdapter() {
+    private void initXsdSchemaHandlerAdapter(ApplicationContext context) {
         try {
             try {
-                xsdSchemaHandlerAdapter = (XsdSchemaHandlerAdapter) getWebApplicationContext()
+                xsdSchemaHandlerAdapter = context
                         .getBean(getXsdSchemaHandlerAdapterBeanName(), XsdSchemaHandlerAdapter.class);
 
             }
@@ -363,14 +372,13 @@
public class MessageDispatcherServlet extends FrameworkServlet {
         }
     }
 
-    private void initMessageReceiver() {
+    private void initMessageReceiver(ApplicationContext context) {
         try {
-            messageReceiver = (WebServiceMessageReceiver) getWebApplicationContext()
-                    .getBean(getMessageReceiverBeanName(), WebServiceMessageReceiver.class);
+            messageReceiver = context.getBean(getMessageReceiverBeanName(), WebServiceMessageReceiver.class);
         }
         catch (NoSuchBeanDefinitionException ex) {
-            messageReceiver = (WebServiceMessageReceiver) defaultStrategiesHelper
-                    .getDefaultStrategy(WebServiceMessageReceiver.class, getWebApplicationContext());
+            messageReceiver = defaultStrategiesHelper
+                    .getDefaultStrategy(WebServiceMessageReceiver.class, context);
             if (messageReceiver instanceof BeanNameAware) {
                 ((BeanNameAware) messageReceiver).setBeanName(getServletName());
             }
@@ -380,10 +388,9 @@
public class MessageDispatcherServlet extends FrameworkServlet {
         }
     }
 
-    /** Find all {@link WsdlDefinition WsdlDefinitions} in the ApplicationContext, incuding ancestor contexts. */
-    private void initWsdlDefinitions() {
+    private void initWsdlDefinitions(ApplicationContext context) {
         wsdlDefinitions = BeanFactoryUtils
-                .beansOfTypeIncludingAncestors(getWebApplicationContext(), WsdlDefinition.class, true, false);
+                .beansOfTypeIncludingAncestors(context, WsdlDefinition.class, true, false);
         if (logger.isDebugEnabled()) {
             for (Map.Entry<String, WsdlDefinition> entry : wsdlDefinitions.entrySet()) {
                 String beanName = entry.getKey();
@@ -393,10 +400,9 @@
public class MessageDispatcherServlet extends FrameworkServlet {
         }
     }
 
-    /** Find all {@link XsdSchema} in the ApplicationContext, incuding ancestor contexts. */
-    private void initXsdSchemas() {
+    private void initXsdSchemas(ApplicationContext context) {
         xsdSchemas = BeanFactoryUtils
-                .beansOfTypeIncludingAncestors(getWebApplicationContext(), XsdSchema.class, true, false);
+                .beansOfTypeIncludingAncestors(context, XsdSchema.class, true, false);
         if (logger.isDebugEnabled()) {
             for (Map.Entry<String, XsdSchema> entry : xsdSchemas.entrySet()) {
                 String beanName = entry.getKey();
